import os
from langgraph.graph import START, END, StateGraph, MessagesState
from langgraph.checkpoint.memory import MemorySaver
from langchain_core.tools import tool
from langchain_community.chat_models.zhipuai import ChatZhipuAI
from typing import Literal
from typing_extensions import TypedDict

# The ZhipuAI chat model created below reads its credential from the
# ZHIPUAI_API_KEY environment variable. Secrets do not belong in source
# control, and a value the user already exported must not be overwritten, so
# the key is only validated here -- never assigned. Fail fast with an
# actionable message instead of a less obvious error from the client later on.
if not os.environ.get("ZHIPUAI_API_KEY"):
    raise RuntimeError(
        "ZHIPUAI_API_KEY is not set; export it before running this script "
        "(for example: export ZHIPUAI_API_KEY=<your key>)."
    )


@tool
def get_weather(city: str):
    """Get the weather for a specific city"""
    # The docstring above is intentionally left as-is: the @tool decorator
    # uses it as the tool description that is shown to the model.
    # Canned reply -- no real weather service is contacted.
    forecast = f"It's sunny in {city}!"
    return forecast

# Base chat model; used directly for free-form replies and as the starting
# point for the structured-output variants.
raw_model = ChatZhipuAI(model="glm-4", temperature=0.95)

# Variant whose output follows the `get_weather` tool schema; model_node
# reads the extracted "city" entry from its result.
model = raw_model.with_structured_output(get_weather)

# State schema of the weather subgraph: everything MessagesState carries
# (the nodes read and write its "messages" entry) plus the city to look up.
class SubGraphState(MessagesState):
    # City name; written by model_node and read by weather_node.
    city: str

def model_node(state: SubGraphState):
    """Extract the city mentioned in the conversation.

    Sends the message history to the structured-output model and returns a
    state update containing only the "city" entry of its result.
    """
    extracted = model.invoke(state["messages"])
    city = extracted["city"]
    return {"city": city}

def weather_node(state: SubGraphState):
    """Look up the weather for the city stored in the state.

    Invokes the `get_weather` tool with ``state["city"]`` and returns the
    tool output wrapped as a single assistant message.
    """
    forecast = get_weather.invoke({"city": state["city"]})
    reply = {"role": "assistant", "content": forecast}
    return {"messages": [reply]}

# --- Weather subgraph: extract the city, then look up its weather ------------
subgraph = StateGraph(SubGraphState)

# Nodes are registered under explicit names that equal the function names,
# which is what the edges below refer to.
subgraph.add_node("model_node", model_node)
subgraph.add_node("weather_node", weather_node)

# Linear flow: START -> model_node -> weather_node -> END.
subgraph.add_edge(START, "model_node")
subgraph.add_edge("model_node", "weather_node")
subgraph.add_edge("weather_node", END)

# From here on `subgraph` refers to the compiled graph, not the builder.
# Execution is interrupted before weather_node so that the extracted city can
# be inspected or overridden before the lookup runs.
subgraph = subgraph.compile(interrupt_before=["weather_node"])

# Checkpointer handed to the parent graph when it is compiled.
memory = MemorySaver()

# State schema of the top-level graph: everything MessagesState carries plus
# the routing decision for the current query.
class RouterState(MessagesState):
    # Written by router_node and read by route_after_prediction.
    route: Literal["weather", "other"]

# Output schema handed to `with_structured_output` for the classification call.
# NOTE(review): documented with comments rather than a docstring on purpose --
# a docstring on a schema class may be forwarded to the model as part of the
# schema description; confirm before adding one.
class Router(TypedDict):
    # "weather" for weather questions, "other" for everything else.
    route: Literal["weather", "other"]

# Classifier variant of the base model: its output follows the Router schema,
# so router_node can read the decision from the "route" key.
router_model = raw_model.with_structured_output(Router)

def router_node(state: RouterState):
    """Classify the conversation as weather-related or not.

    Prepends a system instruction to the message history, asks the
    structured-output router model for a decision and returns it as the
    "route" state update.
    """
    instruction = {
        "role": "system",
        "content": "Classify the incoming query as either about weather or not.",
    }
    decision = router_model.invoke([instruction] + state["messages"])
    return {"route": decision["route"]}

def normal_llm_node(state: RouterState):
    """Answer a non-weather query with the plain chat model.

    The model's reply is returned as a one-element "messages" update.
    """
    return {"messages": [raw_model.invoke(state["messages"])]}

def route_after_prediction(state: RouterState) -> Literal["weather_graph", "normal_llm_node"]:
    """Return the name of the node that should run after classification.

    "weather_graph" is chosen only when ``state["route"]`` equals "weather";
    every other value falls through to "normal_llm_node".
    """
    is_weather_query = state["route"] == "weather"
    return "weather_graph" if is_weather_query else "normal_llm_node"

# --- Top-level router graph ---------------------------------------------------
graph = StateGraph(RouterState)

# Nodes are registered under explicit names that equal the function names;
# the compiled weather subgraph is mounted as a single node.
graph.add_node("router_node", router_node)
graph.add_node("normal_llm_node", normal_llm_node)
graph.add_node("weather_graph", subgraph)

# Every run starts with classification; route_after_prediction then selects
# one of the two branches, and both branches end the run.
graph.add_edge(START, "router_node")
graph.add_conditional_edges("router_node", route_after_prediction)
graph.add_edge("weather_graph", END)
graph.add_edge("normal_llm_node", END)

# From here on `graph` refers to the compiled graph, not the builder. The
# checkpointer is what allows runs to be inspected and resumed by thread_id.
graph = graph.compile(checkpointer=memory)

def _print_updates(inputs, config, **stream_kwargs):
    """Stream ``graph`` in "updates" mode and print every emitted chunk.

    Args:
        inputs: Graph input, or ``None`` to continue an interrupted thread.
        config: Run configuration carrying the ``thread_id`` of the thread.
        **stream_kwargs: Extra keyword arguments forwarded to ``graph.stream``
            (for example ``subgraphs=True``).
    """
    for update in graph.stream(inputs, config=config, stream_mode="updates", **stream_kwargs):
        print(update)


_SEPARATOR = '-------------------------------------------------------------------------------------------'

# Each scenario below runs on its own thread_id so that checkpoints (and
# pending interrupts) from one scenario never leak into another.

# Scenario 1: a query that is not about the weather.
config = {"configurable": {"thread_id": "1"}}
inputs = {"messages": [{"role": "user", "content": "hi!"}]}
_print_updates(inputs, config)

print(_SEPARATOR)

# Scenario 2: a weather query; the run pauses before weather_node.
config = {"configurable": {"thread_id": "2"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
_print_updates(inputs, config)

print(_SEPARATOR)

# Scenario 3: the same query on a fresh thread, this time also streaming the
# updates emitted from inside the subgraph, then inspecting the paused state.
config = {"configurable": {"thread_id": "3"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
_print_updates(inputs, config, subgraphs=True)

# Fetch each snapshot once and print both views of it.
parent_snapshot = graph.get_state(config)
print(parent_snapshot.next)
print(parent_snapshot.tasks)
nested_snapshot = graph.get_state(config, subgraphs=True)
print(nested_snapshot.next)
print(nested_snapshot.tasks)

print(_SEPARATOR)

# Scenario 4: override the city inside the paused subgraph, then resume.
config = {"configurable": {"thread_id": "4"}}
inputs = {"messages": [{"role": "user", "content": "what's the weather in sf"}]}
_print_updates(inputs, config)

state = graph.get_state(config, subgraphs=True)
print(state.values["messages"])
# The router is model-driven, so the query may not have reached the weather
# subgraph; in that case there is no pending task whose state can be updated.
if not state.tasks or state.tasks[0].state is None:
    print(
        "No interrupted subgraph found for this thread (the router may have "
        "classified the query as non-weather); skipping the state update."
    )
else:
    subgraph_config = state.tasks[0].state.config
    print(subgraph_config)
    graph.update_state(subgraph_config, {"city": "la"})
    # Passing None as input continues the thread from where it was paused.
    _print_updates(None, config, subgraphs=True)